import numpy as np

def check_logits(results):
    """Return previously computed mask logits stored in ``results.mask``, if any.

    The entries of ``results.mask`` are scanned in key order. The first entry
    for which ``"mask_logits" in entry`` holds is selected, and its
    ``mask_logits`` attribute is returned.

    NOTE(review): entries are tested with ``in`` but read via attribute
    access, so they are presumably attribute-style mappings; confirm against
    callers.

    Args:
        results: Object exposing a ``mask`` mapping.

    Returns:
        The first ``mask_logits`` found, or ``None`` when no entry has one.
    """
    masks = results.mask
    holder = next(
        (masks[name] for name in masks.keys() if "mask_logits" in masks[name]),
        None,
    )
    if holder is None:
        return None
    return holder.mask_logits

def get_valid(valid, indices):
    """Sum ``valid`` over the positions ``indices`` of its last axis.

    Args:
        valid: Array-like indexable with ``[..., indices]``. For a boolean
            array the result counts the selected ``True`` entries.
        indices: Index or index list applied to the last axis.

    Returns:
        ``np.sum`` of the selected slice over its last axis.
    """
    picked = valid[..., indices]
    return np.sum(picked, axis=-1)


def _apply_flags(values, flags):
    """Return ``values[flags]`` when a selector is given, else ``values`` unchanged."""
    return values if flags is None else values[flags]


def compute_distributional_distance(dist_type, result, comp_result, flags=None, cflags=None):
    """Compute a distance between two distribution results.

    ``result`` and ``comp_result`` must expose ``log_probs`` for the
    likelihood-based distances. For ``"mean"``, ``"was1"`` and ``"was2"``
    they must expose a ``params`` sequence whose entries 0 (and 1 for the
    ``was*`` variants) are tensors. The tensor methods used
    (``.abs()``, ``.mean(dim=...)``, ``.norm(dim=...)``, ``.unsqueeze``)
    suggest PyTorch tensors; TODO confirm.

    Args:
        dist_type: Which distance to compute:
            - ``"target_likelihood"``: ``|mean(comp_result.log_probs)|`` over
              the last axis, with a trailing size-1 axis kept. ``result`` is
              not used.
            - ``"likelihood"``: ``|mean(result.log_probs) -
              mean(comp_result.log_probs)|`` over the last axis, with a
              trailing size-1 axis kept.
            - ``"mean"``: mean absolute difference of ``params[0]`` over the
              last axis.
            - ``"was1"``: mean absolute difference of ``params[0]`` plus that
              of ``params[1]``.
            - ``"was2"``: L2 norm (last axis) of the ``params[0]`` difference
              plus that of the ``params[1]`` difference.
            NOTE(review): ``params[0]`` / ``params[1]`` are presumably
            location / scale, which would make the ``was*`` options
            Wasserstein-style distances; confirm.
        result: First distribution result.
        comp_result: Distribution result to compare against.
        flags: Optional index/mask applied to every tensor read from ``result``.
        cflags: Optional index/mask applied to every tensor read from
            ``comp_result``.

    Returns:
        The distance tensor, or ``None`` if ``dist_type`` is not recognised.
    """
    # TODO: add other distributional distances
    if dist_type == "target_likelihood":
        # Bind unconditionally so the call also works when ``cflags`` is None.
        log_probs = _apply_flags(comp_result.log_probs, cflags)
        return log_probs.mean(dim=-1).unsqueeze(-1).abs()
    if dist_type == "likelihood":
        log_probs = _apply_flags(result.log_probs, flags)
        clog_probs = _apply_flags(comp_result.log_probs, cflags)
        mean_lp = log_probs.mean(dim=-1).unsqueeze(-1)
        mean_clp = clog_probs.mean(dim=-1).unsqueeze(-1)
        return (mean_lp - mean_clp).abs()
    if dist_type in ("mean", "was1", "was2"):
        params = _apply_flags(result.params[0], flags)
        cparams = _apply_flags(comp_result.params[0], cflags)
        if dist_type == "mean":
            return (params - cparams).abs().mean(dim=-1)
        # The second parameter tensors are filtered with the same selectors as
        # the first, into their own names, so both terms compare matching rows.
        params1 = _apply_flags(result.params[1], flags)
        cparams1 = _apply_flags(comp_result.params[1], cflags)
        if dist_type == "was1":
            return (params - cparams).abs().mean(dim=-1) + (params1 - cparams1).abs().mean(dim=-1)
        return (params - cparams).norm(dim=-1) + (params1 - cparams1).norm(dim=-1)
    return None